{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# op IR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import testing"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm\n",
    "from tvm import relay\n",
    "from tvm.relay.testing.temp_op_attr import TempOpAttr\n",
    "from tvm.relay.op import op as _op"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## op 属性"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "属性访问："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the registered \"log\" operator; `log_op` is reused by later cells.\n",
    "log_op = relay.op.get(\"log\")\n",
    "\n",
    "# The operator must report exactly one input.\n",
    "assert log_op.num_inputs == 1"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "注册 op 属性函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@tvm.ir.register_op_attr(\"exp\", \"ftest\")\n",
    "def test(x):\n",
    "    \"\"\"Return ``x + 1``; registered as the ftest attribute of the exp operator.\"\"\"\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "# ftest was registered for exp only, so log must not expose it.\n",
    "assert log_op.get_attr(\"ftest\") is None\n",
    "\n",
    "# The value stored on exp is the decorated function itself and is callable.\n",
    "exp_ftest = relay.op.get(\"exp\").get_attr(\"ftest\")\n",
    "assert exp_ftest(1) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "先为 `exp` 和 `log` 注册 `fadd1`、`fadd2` 属性函数，供下文演示重置："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add1(x):\n",
    "    \"\"\"Return ``x + 1``.\"\"\"\n",
    "    return x + 1\n",
    "\n",
    "\n",
    "def add2(x):\n",
    "    \"\"\"Return ``x + 2``.\"\"\"\n",
    "    return x + 2\n",
    "\n",
    "\n",
    "# Register fadd1 on both exp and log, and fadd2 on log only.\n",
    "tvm.ir.register_op_attr(\"exp\", \"fadd1\", add1)\n",
    "tvm.ir.register_op_attr(\"log\", \"fadd1\", add1)\n",
    "# Trailing semicolon keeps the returned function's repr out of the cell output.\n",
    "tvm.ir.register_op_attr(\"log\", \"fadd2\", add2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "重置 `log` 属性函数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_op = relay.op.get(\"log\")\n",
    "log_op.reset_attr(\"fadd1\")\n",
    "\n",
    "# fadd1 must be gone from log after the reset.\n",
    "assert log_op.get_attr(\"fadd1\") is None\n",
    "\n",
    "# fadd1 on another operator (exp) must be left intact.\n",
    "exp_op = relay.op.get(\"exp\")\n",
    "assert exp_op.get_attr(\"fadd1\")(1) == 2\n",
    "\n",
    "# The other attribute of log (fadd2) must be left intact.\n",
    "assert relay.op.get(\"log\").get_attr(\"fadd2\")(1) == 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## op 临时属性"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# add1 / add2 are the helpers defined earlier in this notebook; they are\n",
    "# reused here instead of being defined a second time.\n",
    "\n",
    "# Set the original attr value to add1.\n",
    "tvm.ir.register_op_attr(\"sqrt\", \"ftest\", add1)\n",
    "\n",
    "with TempOpAttr(\"sqrt\", \"ftest\", add2):\n",
    "    # Inside the context the attr value must be add2.\n",
    "    assert relay.op.get(\"sqrt\").get_attr(\"ftest\")(1) == 3\n",
    "\n",
    "# After leaving the context the attr value must be add1 again.\n",
    "assert relay.op.get(\"sqrt\").get_attr(\"ftest\")(1) == 2"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## op 注册"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "op_name = \"custom_op\"\n",
    "\n",
    "# NOTE(review): the `code(` ... `)code` wrapper looks like a C++ raw-string\n",
    "# delimiter; in a Python raw string it is part of the description text itself.\n",
    "_op.register(op_name, r\"code(Add two tensor with inner broadcasting.)code\")\n",
    "\n",
    "# Configure the freshly registered operator through a single handle.\n",
    "custom_op = _op.get(op_name)\n",
    "custom_op.set_num_inputs(2)\n",
    "for arg_name in (\"data_0\", \"data_1\"):\n",
    "    custom_op.add_argument(arg_name, \"Tensor\", \"The input data tensor.\")\n",
    "# Use the default type relation function.\n",
    "custom_op.add_type_rel(\"Identity\")\n",
    "custom_op.set_support_level(1)\n",
    "_op.register_pattern(op_name, _op.OpPattern.ELEMWISE)\n",
    "_op.register_stateful(op_name, False)\n",
    "\n",
    "# Look the operator up again so the checks read back what get() returns.\n",
    "registered_op = _op.get(op_name)\n",
    "assert registered_op.name == op_name\n",
    "assert registered_op.num_inputs == 2\n",
    "assert registered_op.get_attr(\"TOpPattern\") == _op.OpPattern.ELEMWISE\n",
    "# `==` (not `is`) is kept on purpose: the stored value may not be the Python\n",
    "# False singleton -- TODO confirm.\n",
    "assert registered_op.get_attr(\"TOpIsStateful\") == False"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note}\n",
    "`\"TOpIsStateful\"` 为 `True` 表示算子是有状态的或包含内部状态。\n",
    "\n",
    "我们总是可以通过添加额外的句柄参数并返回它来处理有状态的算子。\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[0;31mSignature:\u001b[0m \u001b[0m_op\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mregister\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescribe\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mSource:\u001b[0m   \n",
      "\u001b[0;32mdef\u001b[0m \u001b[0mregister\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescribe\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0;34m\"\"\"Get the Op for a given name.\u001b[0m\n",
      "\u001b[0;34m    when the op_name is not registered, create a new empty op with the given name.\u001b[0m\n",
      "\u001b[0;34m    when the op_name has been registered, abort with an error message.\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    Parameters\u001b[0m\n",
      "\u001b[0;34m    ----------\u001b[0m\n",
      "\u001b[0;34m    op_name : str\u001b[0m\n",
      "\u001b[0;34m        The operator name\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m    describe : Optional[str]\u001b[0m\n",
      "\u001b[0;34m        The operator description\u001b[0m\n",
      "\u001b[0;34m    \"\"\"\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\n",
      "\u001b[0;34m\u001b[0m    \u001b[0mtvm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mir\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_ffi_api\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mRegisterOp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop_name\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdescribe\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mFile:\u001b[0m      /media/pc/data/lxw/ai/tvm/python/tvm/relay/op/op.py\n",
      "\u001b[0;31mType:\u001b[0m      function"
     ]
    }
   ],
   "source": [
    "_op.register??"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tvmz",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
